import pandas as pd
import numpy as np

def _one_hot_ratings(ratings, classes=(1, 2, 3, 4, 5)):
    """
    One-hot encode ratings against a fixed, ordered set of classes.

    Encoding against a fixed category list (instead of whatever values happen
    to be present) guarantees one column per class, in the same order, for
    every split -- even when a split contains no example of some rating.
    Values that are not in ``classes`` are encoded as an all-zero row.

    :param ratings: array-like of rating values
    :param classes: the rating values, in column order
    :return: 2-D numpy array with ``len(classes)`` columns
    """
    categorical = pd.Categorical(ratings, categories=list(classes))
    return pd.get_dummies(pd.Series(categorical)).values


def load_data(data_dir='resources/amazon_review_full_csv', val_fraction=0.1):
    """
    Load the Amazon Reviews dataset and split it into train/validation/test.

    Reads ``train.csv`` and ``test.csv`` (no header row; columns are rating,
    title, review) from ``data_dir``, lower-cases the review text, shuffles the
    training rows with a fixed seed and holds out the first ``val_fraction`` of
    them as validation data. Ratings (1-5) are returned one-hot encoded with
    exactly five columns for every split.

    :param data_dir: directory containing ``train.csv`` and ``test.csv``
    :param val_fraction: fraction of the training rows held out for validation
    :return: (x_train, y_train), (x_val, y_val), (x_test, y_test)
    :raises ValueError: if ``val_fraction`` is not between 0 and 1
    """
    if not 0 <= val_fraction <= 1:
        raise ValueError(f"val_fraction must be between 0 and 1, got {val_fraction!r}")

    columns = ['rating', 'title', 'review']
    train_data = pd.read_csv(f'{data_dir}/train.csv', header=None, names=columns)
    test_data = pd.read_csv(f'{data_dir}/test.csv', header=None, names=columns)

    # Convert the review text to lower case.
    # NOTE(review): missing reviews stay NaN after .str.lower(); confirm that
    # downstream consumers tolerate non-string entries.
    train_data['review'] = train_data['review'].str.lower()
    test_data['review'] = test_data['review'].str.lower()

    # Shuffle with a fixed seed so the train/validation split is reproducible,
    # then hold out the first `val_fraction` (10 percent by default) of the rows.
    shuffled = train_data.sample(frac=1, random_state=0)
    val_size = int(len(shuffled) * val_fraction)
    val_data = shuffled.iloc[:val_size]
    train_data = shuffled.iloc[val_size:]

    # Style the data as follows: (x_train, y_train), (x_val, y_val), (x_test, y_test)
    x_train = train_data['review'].values
    y_train = _one_hot_ratings(train_data['rating'])
    x_val = val_data['review'].values
    y_val = _one_hot_ratings(val_data['rating'])
    x_test = test_data['review'].values
    y_test = _one_hot_ratings(test_data['rating'])

    return (x_train, y_train), (x_val, y_val), (x_test, y_test)


# Disabled example: estimate the vocabulary size of the training reviews
# (x_train is an array of review strings returned by load_data()).
# words = set()
# (x_train, y_train), (x_val, y_val), (x_test, y_test) = load_data()
# for text in x_train:
#     # Splitting the text into words (this is a very basic way of doing it)
#     tokens = text.split()
#     words.update(tokens)
#
#
# vocab_size = len(words)
# print(f"Vocabulary size: {vocab_size}")

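# Print the distribution of ratings (NumPy arrays have no value_counts() method).
# Note: the y_* arrays returned by load_data() are one-hot encoded, so np.unique
# on them only counts 0s and 1s; use e.g. y_train.sum(axis=0) for per-rating counts.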
# # import numpy as np
# print(np.unique(y_train, return_counts=True))
# print(np.unique(y_val, return_counts=True))
# print(np.unique(y_test, return_counts=True))